
"""
@author: Zhenjiao Jiang <zhenjiao.jiang@csiro.au>, October 2018.

"""

import numpy as np
import tensorflow as tf

# Architecture hyper-parameters shared by encoder() and decoder().

z_size = 20         # dimensionality of the latent vector z

flt_size = 2        # kernel edge for the stride-2 convolutions (and the decoder's output conv)
flt_size2 = 3       # kernel edge for the decoder's stride-1 refinement convolutions

outsize_d = 128     # channels of the decoder's first transposed conv; later layers use /2, /4, /8
outsize_e = 64      # channels of the encoder's first conv; later layers use /2, /4
core_size = 4       # edge of the smallest cubic grid (decoder's dense output is core_size**3)
leak_value = 0.2    # negative slope handed to lrelu()

def _decoder_upsample(inputs, w_name, b_name, in_channels, out_channels,
                      out_spatial, batch_size, phase_train, training,
                      w_init, b_init):
    """Stride-2 transposed 3-D conv -> bias -> batch norm -> leaky ReLU.

    Variables ``w_name``/``b_name`` (and the batch-norm variables) are
    created in the caller's current variable scope, so the caller controls
    naming and reuse. The output has ``out_spatial`` voxels along each
    spatial axis and ``out_channels`` channels.
    """
    # conv3d_transpose filters are laid out [d, h, w, out_channels, in_channels].
    w = tf.get_variable(w_name,
                        shape=[flt_size, flt_size, flt_size, out_channels, in_channels],
                        initializer=w_init)
    b = tf.get_variable(b_name, shape=[out_channels], initializer=b_init)
    out_shape = (batch_size, out_spatial, out_spatial, out_spatial, out_channels)
    h = tf.nn.conv3d_transpose(inputs, w, out_shape,
                               strides=[1, 2, 2, 2, 1], padding="SAME")
    h = tf.nn.bias_add(h, b)
    h = tf.contrib.layers.batch_norm(h, is_training=phase_train,
                                     trainable=training)
    return lrelu(h, leak_value)


def _decoder_refine(inputs, w_name, b_name, channels, w_init, b_init):
    """Stride-1 3-D conv (``flt_size2`` kernel) -> bias -> leaky ReLU.

    Keeps spatial size and channel count; variables are created in the
    caller's current variable scope.
    """
    w = tf.get_variable(w_name,
                        shape=[flt_size2, flt_size2, flt_size2, channels, channels],
                        initializer=w_init)
    b = tf.get_variable(b_name, shape=[channels], initializer=b_init)
    h = tf.nn.conv3d(inputs, w, strides=[1, 1, 1, 1, 1], padding="SAME")
    h = tf.nn.bias_add(h, b)
    return lrelu(h, leak_value)


def decoder(z, batch_size_new, phase_train=True, reuse=False, training=False):
    """Decode latent vectors ``z`` into a single-channel 3-D volume.

    A dense layer maps ``z`` onto a ``core_size**3`` grid. Four stride-2
    transposed convolutions then double every spatial axis (4 -> 64 with
    the default ``core_size``). They are interleaved with stride-1
    refinement convolutions. A final convolution + sigmoid produces one
    output channel.

    Args:
        z: 2-D tensor ``[batch, z_size]`` (rank and width are fixed by the
            ``matmul`` with ``wd0``).
        batch_size_new: batch size used in the reshape and in every
            ``conv3d_transpose`` output shape; it must equal ``z``'s batch
            dimension for the reshape to succeed.
        phase_train: forwarded as ``is_training`` to batch normalization.
        reuse: forwarded to every ``tf.variable_scope``; set True to share
            variables with an earlier call.
        training: forwarded as ``trainable`` to batch normalization, i.e.
            whether its variables are added to ``TRAINABLE_VARIABLES``.
            With the default ``False`` they are not.

    Returns:
        Sigmoid-activated tensor of shape
        ``[batch_size_new, core_size*16, core_size*16, core_size*16, 1]``.

    Note:
        With ``phase_train=True`` batch normalization puts its moving-average
        update ops in ``tf.GraphKeys.UPDATE_OPS``; the training step must
        run them.
    """
    xavier_init = tf.contrib.layers.xavier_initializer()
    zero_init = tf.zeros_initializer()

    # Channel width halves at every upsampling stage.
    ch1 = outsize_d
    ch2 = outsize_d // 2
    ch4 = outsize_d // 4
    ch8 = outsize_d // 8

    with tf.variable_scope("decoder_layer1", reuse=reuse):
        # Dense projection of z onto a single-channel core_size**3 grid.
        wd0 = tf.get_variable("wd0", shape=[z_size, core_size**3], initializer=xavier_init)
        bd0 = tf.get_variable("bd0", shape=[core_size**3], initializer=zero_init)
        g = tf.add(tf.matmul(z, wd0), bd0)
        g = tf.reshape(g, [batch_size_new, core_size, core_size, core_size, 1])

    with tf.variable_scope("decoder_layer2", reuse=reuse):      # spatial core_size*2
        g = _decoder_upsample(g, "wd1", "bd1", 1, ch1, core_size * 2,
                              batch_size_new, phase_train, training,
                              xavier_init, zero_init)
    with tf.variable_scope("decoder_layer3", reuse=reuse):
        g = _decoder_refine(g, "wd2", "bd2", ch1, xavier_init, zero_init)

    with tf.variable_scope("decoder_layer4", reuse=reuse):      # spatial core_size*4
        g = _decoder_upsample(g, "wd3", "bd3", ch1, ch2, core_size * 4,
                              batch_size_new, phase_train, training,
                              xavier_init, zero_init)
    with tf.variable_scope("decoder_layer5", reuse=reuse):
        g = _decoder_refine(g, "wd4", "bd4", ch2, xavier_init, zero_init)

    with tf.variable_scope("decoder_layer6", reuse=reuse):      # spatial core_size*8
        g = _decoder_upsample(g, "wd5", "bd5", ch2, ch4, core_size * 8,
                              batch_size_new, phase_train, training,
                              xavier_init, zero_init)
    with tf.variable_scope("decoder_layer7", reuse=reuse):
        g = _decoder_refine(g, "wd6", "bd6", ch4, xavier_init, zero_init)

    with tf.variable_scope("decoder_layer8", reuse=reuse):      # spatial core_size*16
        g = _decoder_upsample(g, "wd7", "bd7", ch4, ch8, core_size * 16,
                              batch_size_new, phase_train, training,
                              xavier_init, zero_init)

    with tf.variable_scope("decoder_layer9", reuse=reuse):
        # Collapse to one output channel; sigmoid squashes values into (0, 1).
        wd8 = tf.get_variable("wd8", shape=[flt_size, flt_size, flt_size, ch8, 1], initializer=xavier_init)
        bd8 = tf.get_variable("bd8", shape=[1], initializer=zero_init)
        g = tf.nn.conv3d(g, wd8, strides=[1, 1, 1, 1, 1], padding="SAME")
        g = tf.nn.bias_add(g, bd8)
        g = tf.nn.sigmoid(g)

    return g

def lrelu(x, leak=0.2):
    """Leaky ReLU: ``x`` where ``x >= 0``, ``leak * x`` elsewhere.

    Implemented as an element-wise maximum of ``x`` and ``leak * x``,
    which matches leaky ReLU only when ``0 <= leak <= 1``.
    """
    negative_branch = leak*x
    return tf.maximum(x, negative_branch)

def encoder(inputs, phase_train=True, reuse=False, training=False):
    """Encode a single-channel 3-D volume into a ``z_size`` latent vector.

    Four stride-2 3-D convolutions each halve every spatial axis (rounded
    up, SAME padding), and each is followed by bias and leaky ReLU. The
    result is flattened and projected to ``z_size`` features by a dense
    layer.

    Args:
        inputs: 5-D tensor ``[batch, depth, height, width, 1]``; the single
            input channel is fixed by the shape of ``we1``. Non-batch
            dimensions must be statically known so the dense weight shape
            can be derived (a 64**3 input yields ``[core_size**3, z_size]``).
        phase_train: unused; kept for signature parity with ``decoder``.
        reuse: forwarded to every ``tf.variable_scope``; set True to share
            variables with an earlier call.
        training: unused; kept for signature parity with ``decoder``.

    Returns:
        ``(z_latent, mean, stddev)``. ``z_latent`` is ``[batch, z_size]``.
        ``mean`` and ``stddev`` are per-sample statistics of ``z_latent``
        over its feature axis, each of shape ``[batch]``.
        NOTE: ``tf.nn.moments`` returns the *variance*, so despite its name
        the third value is the variance, not the standard deviation. It is
        returned unchanged for compatibility with existing callers.

    Raises:
        ValueError: if the non-batch dimensions of the last convolution's
            output are not statically known.
    """
    xavier_init = tf.contrib.layers.xavier_initializer()
    zero_init = tf.zeros_initializer()

    # Channel progression of the four stride-2 conv layers.
    channels = [1, outsize_e, outsize_e // 2, outsize_e // 4, 1]

    h = inputs
    for idx, (c_in, c_out) in enumerate(zip(channels[:-1], channels[1:]), start=1):
        # Scope/variable names (encoder_layerN, weN, beN) are kept stable so
        # existing checkpoints and reuse=True continue to resolve.
        with tf.variable_scope("encoder_layer%d" % idx, reuse=reuse):
            w = tf.get_variable("we%d" % idx,
                                shape=[flt_size, flt_size, flt_size, c_in, c_out],
                                initializer=xavier_init)
            b = tf.get_variable("be%d" % idx, shape=[c_out], initializer=zero_init)
            h = tf.nn.conv3d(h, w, strides=[1, 2, 2, 2, 1], padding="SAME")
            h = tf.nn.bias_add(h, b)
            # Every layer uses the module-wide slope.
            h = lrelu(h, leak_value)

    with tf.variable_scope("encoder_layer5", reuse=reuse):
        feat_shape = h.get_shape().as_list()[1:]
        if any(d is None for d in feat_shape):
            raise ValueError(
                "encoder needs statically known non-batch dimensions, got %s"
                % (h.get_shape(),))
        # Derive the dense input width from the actual conv output instead of
        # assuming core_size**3 (identical for 64**3 inputs).
        flat_dim = int(np.prod(feat_shape))
        we5 = tf.get_variable("we5", shape=[flat_dim, z_size], initializer=xavier_init)
        be5 = tf.get_variable("be5", shape=[z_size], initializer=zero_init)
        flat = tf.reshape(h, shape=[-1, flat_dim])
        z_latent = tf.add(tf.matmul(flat, we5), be5)

    with tf.variable_scope("encoder_out", reuse=reuse):
        mean, variance = tf.nn.moments(z_latent, axes=[1])

    # Third element is the variance (historically named "stddev").
    return z_latent, mean, variance